{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0hBMTLEWzVro"
      },
      "source": [
        "# Knowledge Distillation with Tunix: Gemma 7B to Gemma 2B\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google/tunix/blob/main/examples/logit_distillation.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/tunix/blob/main/examples/logit_distillation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8hGFhPDbBOSr"
      },
      "source": [
        "## Install necessary libraries\n",
        "\n",
        "This notebook demonstrates how to use the **Tunix** library to perform knowledge distillation. Specifically, we will use **logit-based distillation** to transfer knowledge from a larger, more capable **teacher model (Gemma 7B)** to a smaller, more efficient **student model (Gemma 2B)**.\n",
        "\n",
        "## What is Knowledge Distillation?\n",
        "Knowledge distillation is a model compression technique where a smaller \"student\" model is trained to mimic the behavior of a larger, pre-trained \"teacher\" model. Instead of training the student solely on the ground-truth labels, we also train it to replicate the teacher's outputs.\n",
        "\n",
        "## Logit-Based Distillation\n",
        "In this specific strategy, the student model learns to match the teacher's **logits** (the raw, unnormalized outputs before the final softmax layer). By doing so, the student learns the nuanced probability distribution that the teacher model has learned, which is often more informative than the hard labels alone.\n",
        "\n",
        "The core components we'll use are:\n",
        "-   **Teacher Model**: Gemma 7B\n",
        "-   **Student Model**: Gemma 2B\n",
        "-   **Distillation Strategy**: `tunix.distillation.strategies.LogitStrategy`\n",
        "-   **Trainer**: `tunix.distillation.DistillationTrainer`\n",
        "\n",
        "In this tutorial, we use a v5e-8 TPU. Let's get started!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8-VI2wbWzY6b"
      },
      "outputs": [],
      "source": [
        "!pip install -q kagglehub\n",
        "\n",
        "!pip install -q tensorflow\n",
        "!pip install -q tensorboardX\n",
        "!pip install -q grain-nightly\n",
        "!pip install -q datasets\n",
        "!pip install -q git+https://github.com/google/tunix\n",
        "!pip install -q git+https://github.com/google/qwix\n",
        "\n",
        "!pip uninstall -q -y flax\n",
        "!pip install -q git+https://github.com/google/flax.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c1xPRviCzdgg"
      },
      "outputs": [],
      "source": [
        "import gc\n",
        "import os\n",
        "\n",
        "from flax import nnx\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import kagglehub\n",
        "import optax\n",
        "from orbax import checkpoint as ocp\n",
        "from tunix.distillation import distillation_trainer\n",
        "from tunix.distillation import strategies\n",
        "from tunix.generate import sampler as sampler_lib\n",
        "from tunix.generate import tokenizer_adapter as tokenizer_lib\n",
        "from tunix.models.gemma import model as gemma_lib\n",
        "from tunix.models.gemma import params as params_lib\n",
        "from tunix.examples.data import translation_dataset as data_lib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u0xLXuslziAm"
      },
      "source": [
        "## Utility Function to check HBM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cNgOJM4Yzma7"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "import humanize\n",
        "from tunix.sft import utils\n",
        "\n",
        "show_hbm_usage = utils.show_hbm_usage\n",
        "show_hbm_usage()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KPVVoafkzUhe"
      },
      "outputs": [],
      "source": [
        "# --- Data ---\n",
        "BATCH_SIZE = 4\n",
        "MAX_TARGET_LENGTH = 128\n",
        "NUM_TRAIN_EPOCHS = 1\n",
        "\n",
        "# --- Model ---\n",
        "MESH = [(1, 8), (\"fsdp\", \"tp\")]\n",
        "\n",
        "# --- Training ---\n",
        "MAX_STEPS = 200\n",
        "EVAL_EVERY_N_STEPS = 50\n",
        "LEARNING_RATE = 1e-4\n",
        "\n",
        "# --- Distillation ---\n",
        "TEMPERATURE = 2.0  # Softens the teacher's probabilities\n",
        "ALPHA = 0.7  # Balances distillation loss and student's own task loss\n",
        "\n",
        "# --- Checkpointing ---\n",
        "TEACHER_CKPT_DIR = \"/content/intermediate_ckpt/teacher/\"\n",
        "STUDENT_CKPT_DIR = \"/content/intermediate_ckpt/student/\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QyHsZZ7hzuBl"
      },
      "source": [
        "First, we need to load our teacher and student models. We'll use Gemma 7B as the teacher and Gemma 2B as the student.\n",
        "\n",
        "**Important**: You must have a Kaggle account and agree to the Gemma license to download the models. The first time you run this, you will be prompted to log in to Kaggle."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h1iQ5hzuzrbR"
      },
      "outputs": [],
      "source": [
        "# Log in to Kaggle\n",
        "if \"KAGGLE_USERNAME\" not in os.environ or \"KAGGLE_KEY\" not in os.environ:\n",
        "  kagglehub.login()\n",
        "\n",
        "\n",
        "def load_and_save_model(model_handle, version, ckpt_dir):\n",
        "  \"\"\"Loads a model from Kaggle, saves it locally, and cleans up memory.\"\"\"\n",
        "  print(f\"Loading {model_handle}...\")\n",
        "  kaggle_ckpt_path = kagglehub.model_download(model_handle)\n",
        "  ckpt_version = \"2b-it\"\n",
        "  if \"7b\" in version:\n",
        "    ckpt_version = \"7b-it\"\n",
        "  # Temporarily set the default device to CPU for loading the full model\n",
        "  with jax.default_device(jax.devices(\"cpu\")[0]):\n",
        "    params = params_lib.load_and_format_params(\n",
        "        os.path.join(kaggle_ckpt_path, ckpt_version)\n",
        "    )\n",
        "    gemma = gemma_lib.Gemma.from_params(params, version=version)\n",
        "\n",
        "  print(f\"Saving checkpoint to {ckpt_dir}...\")\n",
        "  checkpointer = ocp.StandardCheckpointer()\n",
        "  _, state = nnx.split(gemma)\n",
        "  checkpointer.save(os.path.join(ckpt_dir, \"state\"), state)\n",
        "  checkpointer.wait_until_finished()\n",
        "  # Clean up to save memory\n",
        "  del params\n",
        "  del gemma\n",
        "  del state\n",
        "  gc.collect()\n",
        "  print(f\"Finished processing {model_handle}.\")\n",
        "\n",
        "\n",
        "# Load Teacher Model (Gemma 7B)\n",
        "load_and_save_model(\n",
        "    \"google/gemma/flax/1.1-7b-it\", \"1.1-7b-it\", TEACHER_CKPT_DIR\n",
        ")\n",
        "\n",
        "# Load Student Model (Gemma 2B)\n",
        "load_and_save_model(\n",
        "    \"google/gemma/flax/1.1-2b-it\", \"1.1-2b-it\", STUDENT_CKPT_DIR\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ALVsADIVz4Iw"
      },
      "source": [
        "Now that we have the checkpoints saved locally, we can load them into sharded models. Sharding is essential for training large models efficiently on TPUs by distributing the model's weights and the computation across multiple devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZNC5c8glz3Mu"
      },
      "outputs": [],
      "source": [
        "def get_sharded_model(ckpt_path, model_config, mesh):\n",
        "  \"\"\"Loads a checkpoint into a sharded model.\"\"\"\n",
        "  abs_gemma: nnx.Module = nnx.eval_shape(\n",
        "      lambda: gemma_lib.Gemma(model_config, rngs=nnx.Rngs(params=0))\n",
        "  )\n",
        "  abs_state = nnx.state(abs_gemma)\n",
        "  abs_state = jax.tree.map(\n",
        "      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),\n",
        "      abs_state,\n",
        "      nnx.get_named_sharding(abs_state, mesh),\n",
        "  )\n",
        "  checkpointer = ocp.StandardCheckpointer()\n",
        "  restored_params = checkpointer.restore(ckpt_path, target=abs_state)\n",
        "\n",
        "  graph_def, _ = nnx.split(abs_gemma)\n",
        "  gemma = nnx.merge(graph_def, restored_params)\n",
        "  return gemma\n",
        "\n",
        "\n",
        "mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))\n",
        "\n",
        "# Create Teacher Model\n",
        "print(\"Creating sharded teacher model (Gemma 7B)...\")\n",
        "teacher_config = gemma_lib.ModelConfig.gemma_7b()\n",
        "teacher_model = get_sharded_model(\n",
        "    os.path.join(TEACHER_CKPT_DIR, \"state\"), teacher_config, mesh\n",
        ")\n",
        "print(\"Teacher model created.\")\n",
        "# nnx.display(teacher_model) # Optional: view model structure\n",
        "\n",
        "# Create Student Model\n",
        "print(\"\\nCreating sharded student model (Gemma 2B)...\")\n",
        "student_config = gemma_lib.ModelConfig.gemma_2b()\n",
        "student_model = get_sharded_model(\n",
        "    os.path.join(STUDENT_CKPT_DIR, \"state\"), student_config, mesh\n",
        ")\n",
        "print(\"Student model created.\")\n",
        "# nnx.display(student_model) # Optional: view model structure\n",
        "\n",
        "show_hbm_usage()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ixCv3zQO0CeD"
      },
      "outputs": [],
      "source": [
        "print(\"Loading tokenizer...\")\n",
        "gemma_tokenizer_path = os.path.join(\n",
        "    kagglehub.model_download(\"google/gemma/flax/1.1-2b-it\"), \"tokenizer.model\"\n",
        ")\n",
        "gemma_tokenizer = tokenizer_lib.Tokenizer(\n",
        "    tokenizer_type='sentencepiece', \n",
        "    tokenizer_path=gemma_tokenizer_path\n",
        ")\n",
        "print(\"Tokenizer loaded.\")\n",
        "\n",
        "print(\"\\nCreating datasets...\")\n",
        "train_ds, validation_ds = data_lib.create_datasets(\n",
        "    dataset_name=\"mtnt/en-fr\",\n",
        "    global_batch_size=BATCH_SIZE,\n",
        "    max_target_length=MAX_TARGET_LENGTH,\n",
        "    num_train_epochs=NUM_TRAIN_EPOCHS,\n",
        "    tokenizer=gemma_tokenizer,\n",
        "    instruct_tuned=True,\n",
        ")\n",
        "print(\"Datasets created.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qHdFCIT00H80"
      },
      "source": [
        "The `LogitStrategy` requires three key functions:\n",
        "1.  `model_forward_fn`: A function that performs a forward pass for a given model and returns its logits. Since both our models are from the Gemma family, we can use the same function for both.\n",
        "2.  `labels_fn`: A function that creates the ground-truth labels from the input data for the standard cross-entropy loss.\n",
        "3.  `gen_model_input_fn`: A helper function to format each batch from the data loader into the dictionary format expected by the model."
      ]
    },
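    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shift between logits and labels can be written out as follows. This is only a sketch: $x_t$, $m_t$, $z_t$ and $y_t$ are notation introduced here, not Tunix names. For a sequence of $n$ tokens $x_1, \\dots, x_n$ with mask $m_1, \\dots, m_n$, the logits $z_t$ produced at position $t$ are scored against the next token $x_{t+1}$:\n",
        "\n",
        "$$\n",
        "y_t = m_{t+1} \\cdot \\mathrm{onehot}(x_{t+1}), \\qquad t = 1, \\dots, n-1\n",
        "$$\n",
        "\n",
        "This is why `model_forward_fn` in the next cell keeps the logits for positions $1, \\dots, n-1$ only, while `labels_fn` starts from the second token. Where $m_{t+1} = 0$, $y_t$ is an all-zero vector, so that position adds nothing to a cross-entropy of the form $-\\sum_t y_t \\cdot \\log \\mathrm{softmax}(z_t)$."
      ]
    },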
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cEez-UGO0Isd"
      },
      "outputs": [],
      "source": [
        "VOCAB_SIZE = student_config.num_embed\n",
        "\n",
        "\n",
        "def model_forward_fn(\n",
        "    model: nnx.Module,\n",
        "    input_tokens: jax.Array,\n",
        "    input_mask: jax.Array,\n",
        "    positions: jax.Array,\n",
        "    attention_mask: jax.Array,\n",
        "):\n",
        "  \"\"\"Performs a forward pass and returns the logits.\"\"\"\n",
        "  logits, _ = model(\n",
        "      input_tokens,\n",
        "      positions,\n",
        "      None,\n",
        "      attention_mask,\n",
        "  )\n",
        "  # Exclude the last step as it does not appear in the targets.\n",
        "  return logits[:, :-1, :]\n",
        "\n",
        "\n",
        "def labels_fn(\n",
        "    input_tokens: jax.Array,\n",
        "    input_mask: jax.Array,\n",
        "    **kwargs,\n",
        "):\n",
        "  \"\"\"Creates one-hot encoded labels for the next-token prediction task.\"\"\"\n",
        "  target_tokens = input_tokens[:, 1:]\n",
        "  target_mask = input_mask[:, 1:]\n",
        "  labels = jax.nn.one_hot(target_tokens, VOCAB_SIZE)\n",
        "  # Mask out the padding tokens from the loss calculation.\n",
        "  return labels * target_mask.astype(labels.dtype)[..., None]\n",
        "\n",
        "\n",
        "def gen_model_input_fn(x: distillation_trainer.TrainingInput):\n",
        "  \"\"\"Formats a batch from the data loader into the model's expected input format.\"\"\"\n",
        "  pad_mask = x.input_tokens != gemma_tokenizer.pad_id()\n",
        "  positions = utils.build_positions_from_mask(pad_mask)\n",
        "  attention_mask = utils.make_causal_attn_mask(pad_mask)\n",
        "  return {\n",
        "      'input_tokens': x.input_tokens,\n",
        "      'input_mask': x.input_mask,\n",
        "      'positions': positions,\n",
        "      'attention_mask': attention_mask,\n",
        "  }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t43x9Fla0S0X"
      },
      "source": [
        "Now we can assemble all the components. We'll instantiate the `LogitStrategy`, configure the `DistillationTrainer`, and start the training process. The trainer will handle the distributed training loop across the available TPU cores."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XBugUhCl0VwK"
      },
      "outputs": [],
      "source": [
        "# 1. Setup the distillation strategy\n",
        "strategy = strategies.LogitStrategy(\n",
        "    student_forward_fn=model_forward_fn,\n",
        "    teacher_forward_fn=model_forward_fn,\n",
        "    labels_fn=labels_fn,\n",
        "    temperature=TEMPERATURE,\n",
        "    alpha=ALPHA,\n",
        ")\n",
        "\n",
        "# 2. Setup the training configuration\n",
        "config = distillation_trainer.TrainingConfig(\n",
        "    eval_every_n_steps=EVAL_EVERY_N_STEPS,\n",
        "    max_steps=MAX_STEPS,\n",
        ")\n",
        "\n",
        "# 3. Setup the optimizer\n",
        "optimizer = optax.adamw(LEARNING_RATE)\n",
        "\n",
        "\n",
        "# Set teacher model in eval mode\n",
        "teacher_model.eval()\n",
        "# Set student model in train mode\n",
        "student_model.train()\n",
        "# 4. Instantiate the trainer\n",
        "trainer = distillation_trainer.DistillationTrainer(\n",
        "    student_model=student_model,\n",
        "    teacher_model=teacher_model,\n",
        "    strategy=strategy,\n",
        "    optimizer=optimizer,\n",
        "    training_config=config,\n",
        ").with_gen_model_input_fn(gen_model_input_fn)\n",
        "\n",
        "# 5. Run training within the mesh context, the first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per, please open a bug. Really appreciated!\n",
        "print(\"Starting distillation training...\")\n",
        "with mesh:\n",
        "  trainer.train(train_ds, validation_ds)\n",
        "print(\"Training complete.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z1_1VkkB0ev_"
      },
      "source": [
        "After training, the student model should have improved its ability to perform the translation task by learning from the teacher. Let's test it with a few sample prompts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rn_V4rw70fXL"
      },
      "outputs": [],
      "source": [
        "print(\"Setting up sampler for evaluation...\")\n",
        "sampler = sampler_lib.Sampler(\n",
        "    transformer=student_model,\n",
        "    tokenizer=gemma_tokenizer,\n",
        "    cache_config=sampler_lib.CacheConfig(\n",
        "        cache_size=MAX_TARGET_LENGTH + 64,\n",
        "        num_layers=student_config.num_layers,\n",
        "        num_kv_heads=student_config.num_kv_heads,\n",
        "        head_dim=student_config.head_dim,\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "13ozlAZP0jl4"
      },
      "outputs": [],
      "source": [
        "input_batch = [\n",
        "    \"Translate this into French:\\nHello, my name is Morgane.\\n\",\n",
        "    \"Translate this into French:\\nThis dish is delicious!\\n\",\n",
        "    \"Translate this into French:\\nI am a student.\\n\",\n",
        "]\n",
        "\n",
        "print(\"Generating translations with the distilled student model...\")\n",
        "with mesh:\n",
        "  out_data = sampler(\n",
        "      input_strings=input_batch,\n",
        "      max_generation_steps=20,\n",
        "  )\n",
        "\n",
        "print(\"\\n--- Evaluation Results ---\")\n",
        "for input_string, out_string in zip(input_batch, out_data.text):\n",
        "  print(f\"----------------------\")\n",
        "  print(f\"Prompt:\\n{input_string}\")\n",
        "  print(f\"Distilled Student's Output:\\n{out_string}\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
